{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Finding Synonyms and Analogies\n",
    "\n",
    "This notebook is taken from a [PyTorch NLP tutorial](https://github.com/joosthub/pytorch-nlp-tutorial-ny2018/blob/master/day_1/0_Using_Pretrained_Embeddings.ipynb); source: [repository for the training tutorial at the 2018 O'Reilly AI Conference in NYC on April 29 and 30, 2018](https://github.com/joosthub/pytorch-nlp-tutorial-ny2018).\n",
    "\n",
    "Other related sources: [chapter 10.6](http://www.d2l.ai/chapter_natural-language-processing/similarity-analogy.html) of [Dive into Deep Learning](http://www.d2l.ai/index.html), and [chapter 10.5](http://www.d2l.ai/chapter_natural-language-processing/glove.html), which explains Word Embedding with Global Vectors (GloVe)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# you might need to install annoy with the following command:\n",
    "# conda install -c conda-forge python-annoy\n",
    "from annoy import AnnoyIndex\n",
    "import numpy as np\n",
    "import torch\n",
    "from tqdm import tqdm_notebook\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "GloVe embeddings can be downloaded from the [GloVe webpage](https://nlp.stanford.edu/projects/glove/).\n",
    "\n",
    "Given the size of the files, it is better not to download them over Wi-Fi!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Notebook configuration.\n",
    "# NOTE: this path needs to be modified so that it points to your local copy\n",
    "# of the GloVe text file.\n",
    "args = Namespace(\n",
    "    glove_filename='/home/lelarge/data/glove/glove.6B.100d.txt'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_word_vectors(filename):\n",
    "    '''\n",
    "    Load word vectors from a GloVe-style text file.\n",
    "\n",
    "    Each non-blank line is expected to hold a word followed by its vector\n",
    "    components, all separated by single spaces.\n",
    "\n",
    "    Args:\n",
    "        filename: path to the embeddings text file (read as UTF-8).\n",
    "\n",
    "    Returns:\n",
    "        (word_to_index, word_vectors): a dict mapping each word to its\n",
    "        index, and a list of 1-D numpy arrays such that\n",
    "        word_vectors[word_to_index[word]] is the vector of that word.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a vector component cannot be parsed as a float.\n",
    "    '''\n",
    "    word_to_index = {}\n",
    "    word_vectors = []\n",
    "\n",
    "    # Explicit encoding: otherwise the platform default is used and\n",
    "    # non-ASCII words can be mangled or fail to decode.\n",
    "    with open(filename, encoding='utf-8') as fp:\n",
    "        # readlines() gives the progress bar a known total.\n",
    "        for line in tqdm_notebook(fp.readlines(), leave=False):\n",
    "            if not line.strip():\n",
    "                # Tolerate blank lines instead of storing an empty vector.\n",
    "                continue\n",
    "            fields = line.rstrip().split(' ')\n",
    "\n",
    "            word = fields[0]\n",
    "            if word in word_to_index:\n",
    "                # Keep the first occurrence so that indices stay aligned\n",
    "                # with positions in word_vectors.\n",
    "                continue\n",
    "            word_to_index[word] = len(word_vectors)\n",
    "            word_vectors.append(np.array([float(x) for x in fields[1:]]))\n",
    "\n",
    "    return word_to_index, word_vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "word_to_index, word_vectors = load_word_vectors(args.glove_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "400000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(word_vectors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100,)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_vectors[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3366"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "word_to_index['beautiful']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class PreTrainedEmbeddings(object):\n",
    "    '''\n",
    "    Pre-trained word vectors plus an Annoy index for neighbour queries.\n",
    "\n",
    "    Attributes:\n",
    "        word_to_index: dict mapping a word to its integer index.\n",
    "        index_to_word: inverse mapping of word_to_index.\n",
    "        word_vectors: list of vectors, addressed by word index.\n",
    "        word_vector_size: length of the first loaded vector.\n",
    "        index: AnnoyIndex holding one item per entry of word_to_index.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, glove_filename, metric='euclidean', n_trees=50):\n",
    "        '''\n",
    "        Load the vectors and build the nearest-neighbour index.\n",
    "\n",
    "        Args:\n",
    "            glove_filename: path handed to load_word_vectors.\n",
    "            metric: metric name handed to AnnoyIndex.\n",
    "            n_trees: number of trees handed to AnnoyIndex.build.\n",
    "        '''\n",
    "        self.word_to_index, self.word_vectors = load_word_vectors(glove_filename)\n",
    "        self.word_vector_size = len(self.word_vectors[0])\n",
    "\n",
    "        self.index_to_word = {v: k for k, v in self.word_to_index.items()}\n",
    "        self.index = AnnoyIndex(self.word_vector_size, metric=metric)\n",
    "        print('Building Index')\n",
    "        for i in tqdm_notebook(self.word_to_index.values(), leave=False):\n",
    "            self.index.add_item(i, self.word_vectors[i])\n",
    "        self.index.build(n_trees)\n",
    "        print('Finished!')\n",
    "\n",
    "    def get_embedding(self, word):\n",
    "        '''Return the vector stored for word (KeyError if word is unknown).'''\n",
    "        return self.word_vectors[self.word_to_index[word]]\n",
    "\n",
    "    def closest_v(self, vector, n=1):\n",
    "        '''Return the n words the index reports as nearest to vector.'''\n",
    "        nn_indices = self.index.get_nns_by_vector(vector, n)\n",
    "        return [self.index_to_word[neighbor] for neighbor in nn_indices]\n",
    "\n",
    "    def closest(self, word, n=1):\n",
    "        '''\n",
    "        Return the n words the index reports as nearest to word.\n",
    "\n",
    "        The query word is itself stored in the index, so it can appear in\n",
    "        the result. Raises KeyError if word is unknown.\n",
    "        '''\n",
    "        # Delegate to closest_v so the lookup logic lives in one place.\n",
    "        return self.closest_v(self.get_embedding(word), n)\n",
    "\n",
    "    def sim(self, w1, w2):\n",
    "        '''Return the (unnormalised) dot product of the vectors of w1 and w2.'''\n",
    "        return np.dot(self.get_embedding(w1), self.get_embedding(w2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Building Index\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=400000), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished!\n"
     ]
    }
   ],
   "source": [
    "glove = PreTrainedEmbeddings(args.glove_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['apple', 'microsoft', 'pc', 'software', 'google']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove.closest('apple', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['chip', 'chips', 'semiconductor', 'intel', 'makers']"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove.closest('chip', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['baby', 'boy', 'girl', 'mom', 'child']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove.closest('baby', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['beautiful', 'lovely', 'wonderful', 'brilliant', 'handsome']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove.closest('beautiful', n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def SAT_analogy(w1, w2, w3, embeddings=None):\n",
    "    '''\n",
    "    Solves problems of the type:\n",
    "    w1 : w2 :: w3 : __\n",
    "\n",
    "    The answer is looked up near the vector w3 + (w2 - w1); the three\n",
    "    input words are excluded from the candidates. The result is printed,\n",
    "    or ':-(' when a word is unknown or no candidate is left.\n",
    "\n",
    "    Args:\n",
    "        w1, w2, w3: the three known words of the analogy.\n",
    "        embeddings: object providing get_embedding(word) and\n",
    "            closest_v(vector, n); defaults to the notebook-level glove.\n",
    "    '''\n",
    "    if embeddings is None:\n",
    "        embeddings = glove\n",
    "\n",
    "    try:\n",
    "        w1v = embeddings.get_embedding(w1)\n",
    "        w2v = embeddings.get_embedding(w2)\n",
    "        w3v = embeddings.get_embedding(w3)\n",
    "    except KeyError:\n",
    "        # Only an out-of-vocabulary word is an expected failure; any other\n",
    "        # exception is a real error and is allowed to propagate.\n",
    "        closest_words = []\n",
    "    else:\n",
    "        w4v = w3v + (w2v - w1v)\n",
    "        # Ask for 5 neighbours so that something is left after dropping\n",
    "        # the (up to 3) input words.\n",
    "        closest_words = [w for w in embeddings.closest_v(w4v, n=5)\n",
    "                         if w not in (w1, w2, w3)]\n",
    "\n",
    "    if not closest_words:\n",
    "        print(':-(')\n",
    "    else:\n",
    "        print('{} : {} :: {} : {}'.format(w1, w2, w3, closest_words[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : he :: woman : she\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('man', 'he', 'woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fly : plane :: sail : ship\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('fly', 'plane', 'sail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "beijing : china :: tokyo : japan\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('beijing', 'china', 'tokyo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : woman :: son : daughter\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('man', 'woman', 'son')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "man : leader :: woman : opposition\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('man', 'leader', 'woman')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "woman : leader :: man : leaders\n"
     ]
    }
   ],
   "source": [
    "SAT_analogy('woman', 'leader', 'man')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
